{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "597347c3-f5a0-4fd5-a840-e23957536b3e",
   "metadata": {},
   "source": [
    "# Text Embeddings with `sentence-transformers`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb9eff2b-f5df-4c9e-85a6-e39846749e75",
   "metadata": {},
   "source": [
    "#### We'll start off by installing some dependencies: `sentence-transformers` for the models and `milvus` for the vector database. Milvus is known for its scalability and wide adoption among organiziations, but we have an \"embedded\" version too!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e8973196-c967-40bf-9adf-e9ef2df55f59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: sentence_transformers in /Users/zilliz/.pyenv/lib/python3.9/site-packages (2.2.2)\n",
      "Requirement already satisfied: torchvision in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (0.16.2)\n",
      "Requirement already satisfied: numpy in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (1.26.2)\n",
      "Requirement already satisfied: sentencepiece in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (0.1.99)\n",
      "Requirement already satisfied: scikit-learn in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (1.3.2)\n",
      "Requirement already satisfied: torch>=1.6.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (2.1.2)\n",
      "Requirement already satisfied: huggingface-hub>=0.4.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (0.19.4)\n",
      "Requirement already satisfied: nltk in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (3.8.1)\n",
      "Requirement already satisfied: tqdm in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (4.66.1)\n",
      "Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (4.36.1)\n",
      "Requirement already satisfied: scipy in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (1.12.0)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2023.10.0)\n",
      "Requirement already satisfied: requests in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2.31.0)\n",
      "Requirement already satisfied: packaging>=20.9 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (23.2)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (4.8.0)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (6.0.1)\n",
      "Requirement already satisfied: filelock in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (3.13.1)\n",
      "Requirement already satisfied: sympy in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torch>=1.6.0->sentence_transformers) (1.12)\n",
      "Requirement already satisfied: networkx in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torch>=1.6.0->sentence_transformers) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torch>=1.6.0->sentence_transformers) (3.1.2)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.4.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (2023.10.3)\n",
      "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.15.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from jinja2->torch>=1.6.0->sentence_transformers) (2.1.3)\n",
      "Requirement already satisfied: joblib in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from nltk->sentence_transformers) (1.3.2)\n",
      "Requirement already satisfied: click in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from nltk->sentence_transformers) (8.1.7)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.6)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2.1.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.3.2)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2023.11.17)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from scikit-learn->sentence_transformers) (3.2.0)\n",
      "Requirement already satisfied: mpmath>=0.19 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sympy->torch>=1.6.0->sentence_transformers) (1.3.0)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torchvision->sentence_transformers) (10.1.0)\n",
      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.3.2 is available.\n",
      "You should consider upgrading via the '/Users/zilliz/.pyenv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n",
      "Requirement already satisfied: milvus in /Users/zilliz/.pyenv/lib/python3.9/site-packages (2.3.5)\n",
      "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.3.2 is available.\n",
      "You should consider upgrading via the '/Users/zilliz/.pyenv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install -U sentence_transformers\n",
    "!pip install -U milvus"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ef63d9e-898c-4167-a3a1-4268540ade60",
   "metadata": {},
   "source": [
    "##### We'll go over the basics first: specifying a model and computing its embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3162309c-c311-47e6-994f-b08817dbcb57",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "model = SentenceTransformer(\"llmrails/ember-v1\")\n",
    "model.max_seq_length = 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "51182ae0-2834-41ff-b257-8178c83fff7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1024,)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_0 = model.encode(\"Zilliz is an awesome vector database.\")\n",
    "embedding_0\n",
    "embedding_0.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7550b3a6-ff07-44cb-b757-fbe13a3266f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.9391, 0.5762, 0.5002, 0.3822, 0.3003]])\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers.util import cos_sim\n",
    "sentences = [\"Zilliz is a vector data store that is amazing.\",\n",
    "             \"Unstructured data can be semantically represented with embeddings.\",\n",
    "             \"Singular value decomposition factorizes the input matrix into three other matrices.\",\n",
    "             \"My favorite chess opening is the King's Gambit.\",\n",
    "             \"It doesn't matter if a cat is black or white, so long as it catches mice.\"]\n",
    "embeddings = model.encode(sentences)\n",
    "print(cos_sim(embedding_0, embeddings))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "72e43c27-7c1b-4065-ae2c-de8911aac593",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cos_sim(model.encode(\"I like green eggs and ham.\"), model.encode(\"I like green eggs and ham.\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2868a7f8-4fcd-4479-8475-a21fe94ce5d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.8867]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cos_sim(model.encode(\"Let's eat, Chris.\"), model.encode(\"Let's eat Chris!\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "080e9943-9cda-4019-84cf-bc1386915ee2",
   "metadata": {},
   "source": [
    "#### Now let's check out how to fine-tune our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5c8c019a-3dfe-4528-aeea-f17ecdcef745",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import InputExample\n",
    "train_examples = [\n",
    "    InputExample(texts=[\"Give me a quote on pragmatism.\", \"Whether the cat is black or white doesn't matter, so long as it catches mice.\"], label=1.0),\n",
    "    InputExample(texts=[\"Y Combinator's 7th birthday was March 11.\", \"As usual we were so busy we didn't notice till a few days after.\"], label=1.0)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0b2376ea-1a5a-4db4-a5ac-5f9d9f8754ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import losses\n",
    "from torch.utils.data import DataLoader\n",
    "train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1)\n",
    "train_loss = losses.CosineSimilarityLoss(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "67bdadea-f523-48b0-968c-ae2c1bbf6ea9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b1076953670043509a1699051cfd043b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90bbc8a403df4a35908fc79d46bc5a06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99484456-6861-4b3e-832a-bc3b03eaf01e",
   "metadata": {},
   "source": [
    "#### How about inserting into a vector database?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2be0b673-ce96-4145-b85c-44d9be862ad1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "from milvus import default_server\n",
    "default_server.start()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "39f73880-4a96-4b53-b1db-9f9627e2d09b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import connections\n",
    "connections.connect(alias=\"default\",\n",
    "                    host=\"127.0.0.1\", \n",
    "                    port=default_server.listen_port,\n",
    "                    show_startup_banner=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ff37f8a9-a39f-4bd9-82fa-04ec1bb34219",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pymilvus import utility, FieldSchema, DataType, Collection, CollectionSchema\n",
    "\n",
    "if utility.has_collection(\"default\"):\n",
    "    utility.drop_collection(\"default\")\n",
    "\n",
    "fields = [\n",
    "    FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name=\"embedding\", dtype=DataType.FLOAT_VECTOR, dim=1024)\n",
    "]\n",
    "schema = CollectionSchema(fields=fields)\n",
    "\n",
    "collection = Collection(name=\"default\", schema=schema)\n",
    "\n",
    "index_params = {\n",
    "    \"index_type\": \"HNSW\",\n",
    "    \"metric_type\": \"COSINE\",\n",
    "    \"params\": {\"M\": 64, \"ef\": 32, \"efConstruction\": 32}\n",
    "}\n",
    "collection.create_index(field_name=\"embedding\", index_params=index_params)\n",
    "collection.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c24f8be4-3fd1-4095-bc29-99ef4e5c8957",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(insert count: 5, delete count: 0, upsert count: 0, timestamp: 447271173950799876, success count: 5, err count: 0)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "collection.insert([{\"embedding\": e} for e in embeddings])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "44af4639-928a-4bf1-b62b-1ecb19cff83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_server.stop()\n",
    "default_server.cleanup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfbef666-0836-46a6-90f8-c4be9f441697",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
